import seaborn as sns
import matplotlib
import matplotlib.pyplot as plt
matplotlib.use('Agg')
import pandas as pd
import numpy as np
sns.set_style("whitegrid")
plt.rcParams["font.family"] = 'DejaVu Sans'

# Hard-coded results for the first model. It is paired with the name
# 'QWEN 2.5 72B' through the parallel model_names / model_data lists defined
# later in this file.
#
# Every table is column-oriented (parallel lists, one entry per plotted point)
# and is converted to a pandas DataFrame inside create_subplot_with_errorbars:
#   method          - display name; also the key into `colors` / `marker_map`
#   baseline        - descriptive label (not read by the plotting code visible
#                     in this file)
#   threshold       - sweep parameter; None for single-point baselines
#   mean_factscore  - y coordinate
#   mean_num_facts  - x coordinate

# ConGrs sweep: one point per threshold tau.
model1_congrs_data = {
    'method': ['ConGrs'] * 5,
    'baseline': ['ConGrs (tau=0.1)', 'ConGrs (tau=0.3)', 'ConGrs (tau=0.5)', 'ConGrs (tau=0.7)', 'ConGrs (tau=0.9)'],
    'threshold': [0.1, 0.3, 0.5, 0.7, 0.9],
    'mean_factscore': [0.6802, 0.7530, 0.8433, 0.8816, 0.9028],
    'mean_num_facts': [29.518, 18.680, 12.150, 8.124, 5.250],
}

# ASC sweep: one point per integer threshold theta.
model1_asc_data = {
    'method': ['ASC'] * 5,
    'baseline': ['ASC (theta = 1)', 'ASC (theta = 2)', 'ASC (theta = 3)', 'ASC (theta = 4)', 'ASC (theta = 5)'],
    'threshold': [1, 2, 3, 4, 5],
    'mean_factscore': [0.6717, 0.7276, 0.7693, 0.8703, 0.9112],
    'mean_num_facts': [30.982, 16.968, 9.316, 4.462, 3.246]
}

# Single-point baselines (no threshold to sweep).
# NOTE(review): the 'QwQ 32B' entry (0.5873, 21.448) is identical in all three
# model tables - presumably one model-independent reference run; confirm.
model1_other_methods_data = {
    'method': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'baseline': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'threshold': [None] * 6,
    'mean_factscore': [0.6856, 0.7252, 0.7257, 0.7233, 0.7070, 0.5873],
    'mean_num_facts': [18.466, 18.378, 21.274, 18.988, 17.8932, 21.448]
}


# Hard-coded results for the second model. It is paired with the name
# 'LLAMA 3.3 70B' through the parallel model_names / model_data lists defined
# later in this file. Columns: 'method' (display name), 'baseline' (label),
# 'threshold' (sweep parameter or None), 'mean_factscore' (y),
# 'mean_num_facts' (x).

# ConGrs sweep: one point per threshold tau.
model2_congrs_data = {
    'method': ['ConGrs'] * 5,
    'baseline': ['ConGrs (tau=0.1)', 'ConGrs (tau=0.3)', 'ConGrs (tau=0.5)', 'ConGrs (tau=0.7)', 'ConGrs (tau=0.9)'],
    'threshold': [0.1, 0.3, 0.5, 0.7, 0.9],
    'mean_factscore': [0.6751, 0.8349, 0.8768, 0.9066, 0.9337],
    'mean_num_facts': [22.014, 11.568, 7.218, 4.928, 2.844],
}

# ASC sweep: one point per integer threshold theta.
model2_asc_data = {
    'method': ['ASC'] * 5,
    'baseline': ['ASC (theta = 1)', 'ASC (theta = 2)', 'ASC (theta = 3)', 'ASC (theta = 4)', 'ASC (theta = 5)'],
    'threshold': [1, 2, 3, 4, 5],
    'mean_factscore': [0.6437, 0.7920, 0.8474, 0.8997, 0.9462],
    'mean_num_facts': [23.666, 11.204, 5.096, 3.388, 1.566]
}

# Single-point baselines (no threshold to sweep).
# NOTE(review): the 'QwQ 32B' entry (0.5873, 21.448) is identical in all three
# model tables - presumably one model-independent reference run; confirm.
model2_other_methods_data = {
    'method': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'baseline': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'threshold': [None] * 6,
    'mean_factscore': [0.6810, 0.7467, 0.7223, 0.7123, 0.6949, 0.5873],
    'mean_num_facts': [13.274, 10.542, 15.134, 11.626, 12.6708, 21.448]
}

# Hard-coded results for the third model. It is paired with the name
# 'OLMO 2 32B' through the parallel model_names / model_data lists defined
# later in this file. Columns: 'method' (display name), 'baseline' (label),
# 'threshold' (sweep parameter or None), 'mean_factscore' (y),
# 'mean_num_facts' (x).

# ConGrs sweep: one point per threshold tau.
model3_congrs_data = {
    'method': ['ConGrs'] * 5,
    'baseline': ['ConGrs (tau=0.1)', 'ConGrs (tau=0.3)', 'ConGrs (tau=0.5)', 'ConGrs (tau=0.7)', 'ConGrs (tau=0.9)'],
    'threshold': [0.1, 0.3, 0.5, 0.7, 0.9],
    'mean_factscore': [0.7370, 0.8156, 0.8812, 0.9126, 0.9437],
    'mean_num_facts': [36.258, 20.308, 12.412, 7.660, 4.038],
}

# ASC sweep: one point per integer threshold theta.
model3_asc_data = {
    'method': ['ASC'] * 5,
    'baseline': ['ASC (theta = 1)', 'ASC (theta = 2)', 'ASC (theta = 3)', 'ASC (theta = 4)', 'ASC (theta = 5)'],
    'threshold': [1, 2, 3, 4, 5],
    'mean_factscore': [0.7151, 0.7516, 0.8580, 0.9201, 0.9666],
    'mean_num_facts': [40.498, 15.918, 6.270, 2.454, 0.808]
}

# Single-point baselines (no threshold to sweep).
# NOTE(review): the 'QwQ 32B' entry (0.5873, 21.448) is identical in all three
# model tables - presumably one model-independent reference run; confirm.
model3_other_methods_data = {
    'method': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'baseline': ['Greedy', 'Shortest', 'LM consensus', 'MBR', 'Mean of m', 'QwQ 32B'],
    'threshold': [None] * 6,
    'mean_factscore': [0.7425, 0.7612, 0.7660, 0.7721, 0.7507, 0.5873],
    'mean_num_facts': [22.726, 23.306, 25.886, 24.292, 23.854, 21.448]
}


# One row of three side-by-side panels; each panel is later filled in by
# create_subplot_with_errorbars for one model.
fig, axes = plt.subplots(1, 3, figsize=(43.2, 12))

# Per-method colours, keyed by display name (the values of the 'method'
# columns and of internal_name_to_display_name).
# https://colorbrewer2.org
colors = {
    'ConGrs': '#bc80bd',      # Purple
    'ASC': '#8dd3c7',         # Light teal
    'Greedy': '#bebada',      # Light purple
    'Shortest': '#fb8072',    # Light red
    'LM consensus': '#fccde5', # Light pink
    'MBR': '#fdb462',         # Light orange
    'Mean of m': '#b3de69',   # Light green
    'QwQ 32B': '#80b1d3'      # Light blue
}

# Keys below are the values of the 'Baselines' column in the per-model CSV
# files. The thresholded families (ConGrs, ASC) contribute five keys each that
# all share one display name and one z-order, so they are generated instead of
# being spelled out.
_congrs_internal_names = [f'CONGRS ({tau})' for tau in ('0.1', '0.3', '0.5', '0.7', '0.9')]
_asc_internal_names = [f'ASC (Threshold = {theta})' for theta in range(1, 6)]
_single_point_display_names = {
    'Temp 0': 'Greedy',
    'Short Resp': 'Shortest',
    'LLM Cons w Abs': 'LM consensus',
    'MBR': 'MBR',
    'Mean of N': 'Mean of m',
    'Qwen QWQ 32B': 'QwQ 32B',
}

# CSV baseline name -> display name (the key into `colors` / `marker_map`).
# Insertion order is ConGrs, single-point baselines, ASC.
internal_name_to_display_name = {
    **dict.fromkeys(_congrs_internal_names, 'ConGrs'),
    **_single_point_display_names,
    **dict.fromkeys(_asc_internal_names, 'ASC'),
}

# CSV baseline name -> scatter z-order: ConGrs markers (6) are drawn above the
# single-point baselines (5), which are drawn above the ASC markers (4).
internal_name_to_zorder = {
    **dict.fromkeys(_congrs_internal_names, 6),
    **dict.fromkeys(_single_point_display_names, 5),
    **dict.fromkeys(_asc_internal_names, 4),
}

# Display name -> matplotlib marker code.
marker_map = {
    'Greedy': '^',
    'Shortest': 'v',
    'LM consensus': 'D',
    'MBR': 'p',
    'Mean of m': 'h',
    'QwQ 32B': 'X',
    'ASC': 's',
    'ConGrs': 'o',
}


def set_xaxis_limits_from_max_value(ax, df_congrs, df_asc, df_others):
    """Set the x-axis of *ax* to run from 0 to the largest x value, rounded up.

    Args:
        ax: Object exposing ``set_xlim(lower, upper)`` (a matplotlib Axes in
            this script).
        df_congrs: Table with a ``'mean_num_facts'`` column.
        df_asc: Table with a ``'mean_num_facts'`` column.
        df_others: Table with a ``'mean_num_facts'`` column.

    The upper limit is the maximum ``'mean_num_facts'`` across the three
    tables, rounded up to the next integer with ``math.ceil``.
    """
    # ``math`` is not among the imports at the top of this file, so it is
    # brought into scope here; without it this function fails with NameError.
    import math

    largest = max(
        table['mean_num_facts'].max()
        for table in (df_congrs, df_asc, df_others)
    )
    ax.set_xlim(0, math.ceil(largest))

# custom function with error bars
def _annotate_threshold_points(ax, df_all):
    """Label every ConGrs and ASC point on *ax* with its threshold value."""
    # ConGrs labels are offset up and to the right of each point.
    for _, row in df_all[df_all['method'] == 'ConGrs'].iterrows():
        ax.annotate(rf'$\tau={row["threshold"]}$',
                    (row['mean_num_facts'], row['mean_factscore']),
                    xytext=(14, 14), textcoords='offset points',
                    fontsize=24, fontweight='bold', alpha=0.9,
                    bbox=dict(boxstyle="round,pad=0.1",
                              fc=colors['ConGrs'], lw=0, alpha=0.15))

    # ASC labels hang below each point. The two end points of the sweep
    # (threshold 1 and 5) are centred under the marker with their own offsets;
    # the interior points are placed to the lower left.
    for _, row in df_all[df_all['method'] == 'ASC'].iterrows():
        if row['threshold'] == 1:
            placement = dict(xytext=(0, -20), ha='center')
        elif row['threshold'] == 5:
            placement = dict(xytext=(-8, -30), ha='center')
        else:
            placement = dict(xytext=(-15, -12), ha='right')
        ax.annotate(rf'$\Theta={int(row["threshold"])}$',
                    (row['mean_num_facts'], row['mean_factscore']),
                    textcoords='offset points', va='top',
                    fontsize=24, fontweight='bold', alpha=0.9,
                    bbox=dict(boxstyle="round,pad=0.1",
                              fc=colors['ASC'], lw=0, alpha=0.15),
                    **placement)


def _plot_csv_points(ax, model_name, errorbar_data):
    """Draw one marker with x/y error bars per baseline found in *errorbar_data*."""
    # internal_name_to_display_name lists the baselines in drawing order
    # (ConGrs, single-point baselines, ASC).
    for method, display_name in internal_name_to_display_name.items():
        this_df = errorbar_data.loc[errorbar_data['Baselines'] == method]
        if this_df.empty:
            continue

        num_facts = this_df['Mean number of supported facts']
        factscore = this_df['Mean FActScore']
        mean_x = num_facts.mean()
        mean_y = factscore.mean()

        # Error bars span one standard deviation (pandas ``.std()``) of the
        # CSV rows for this baseline, in each direction.
        ax.errorbar(mean_x, mean_y,
                    xerr=num_facts.std(), yerr=factscore.std(),
                    ecolor='#777', elinewidth=2, zorder=4, capsize=4, capthick=2)
        ax.scatter(mean_x, mean_y,
                   c=colors[display_name], s=500,
                   edgecolors='#fff',
                   marker=marker_map[display_name],
                   zorder=internal_name_to_zorder[method])
        print(f"{model_name} - {method}: {mean_x}, {mean_y}")


def _plot_hardcoded_points(ax, df_all):
    """Draw one marker per row of *df_all* (fallback when no CSV is available)."""
    # Same layering as internal_name_to_zorder: ConGrs on top, ASC at the bottom.
    fallback_zorder = {'ConGrs': 6, 'ASC': 4}
    for method in df_all['method'].unique():
        for _, row in df_all[df_all['method'] == method].iterrows():
            ax.scatter(row['mean_num_facts'], row['mean_factscore'],
                       c=colors[method], s=500, edgecolors='#fff',
                       marker=marker_map[method],
                       zorder=fallback_zorder.get(method, 5))


def create_subplot_with_errorbars(ax, subplot_idx, model_name, congrs_data, asc_data, other_methods_data, csv_file):
    """Draw one model's FActScore vs. number-of-supported-facts panel on *ax*.

    The ConGrs and ASC curves and their threshold labels always come from the
    hard-coded tables. The markers come from *csv_file* (mean of the rows per
    baseline, with standard-deviation error bars) when that file exists, and
    from the hard-coded tables otherwise.

    Args:
        ax: Axes to draw on.
        subplot_idx: Panel index; only panel 0 gets a y-axis label.
        model_name: Used only in the messages printed to stdout.
        congrs_data: Column-oriented dict for the ConGrs sweep with keys
            'method', 'threshold', 'mean_factscore', 'mean_num_facts'.
        asc_data: Same layout, for the ASC sweep.
        other_methods_data: Same layout, for the single-point baselines.
        csv_file: Path handed to ``pandas.read_csv``. The file is expected to
            have the columns 'Baselines', 'Mean number of supported facts' and
            'Mean FActScore'.

    Returns:
        None. The function only mutates *ax* and prints to stdout.
    """
    df_all = pd.concat(
        [pd.DataFrame(congrs_data), pd.DataFrame(asc_data), pd.DataFrame(other_methods_data)],
        ignore_index=True,
    )

    # Trade-off curves, connected in threshold order; ConGrs is layered above ASC.
    for method, line_zorder in (('ConGrs', 4), ('ASC', 3)):
        curve = df_all[df_all['method'] == method].sort_values('threshold')
        ax.plot(curve['mean_num_facts'], curve['mean_factscore'],
                color=colors[method], linewidth=8, alpha=0.9,
                linestyle='-', zorder=line_zorder)

    _annotate_threshold_points(ax, df_all)

    # Only the file read is guarded: it is the only step expected to raise
    # FileNotFoundError, and a narrow ``try`` keeps plotting errors visible.
    try:
        errorbar_data = pd.read_csv(csv_file)
    except FileNotFoundError:
        print(f"CSV file {csv_file} not found, plotting without error bars for {model_name}")
        _plot_hardcoded_points(ax, df_all)
    else:
        _plot_csv_points(ax, model_name, errorbar_data)

    ax.set_xlabel('Mean number of supported facts', fontsize=40, fontweight='bold')
    if subplot_idx == 0:
        ax.set_ylabel('Mean FActScore', fontsize=40, fontweight='bold', labelpad=10)

    ax.set_ylim(0.55, 1.0)

    ax.tick_params(axis='both', which='major', labelsize=30, length=16, width=2)


# Three parallel lists, one entry per panel (left to right).
# The names are passed through as model_name, which the plotting function only
# uses in its printed messages.
model_names = ['QWEN 2.5 72B', 'LLAMA 3.3 70B', 'OLMO 2 32B']

# (ConGrs sweep, ASC sweep, single-point baselines) for each model.
model_data = [
    (model1_congrs_data, model1_asc_data, model1_other_methods_data),
    (model2_congrs_data, model2_asc_data, model2_other_methods_data),
    (model3_congrs_data, model3_asc_data, model3_other_methods_data),
]

# Per-model CSV read for the markers and error bars; a missing file makes the
# plotting function fall back to the hard-coded tables.
csv_files = [
    'graph-data/factscore-trade-off-qwen-popqa-data.csv',
    'graph-data/factscore-trade-off-llama-popqa-data.csv',
    'graph-data/factscore-trade-off-olmo-popqa-data.csv',
]

# Fill the panels in order; the index doubles as the subplot position.
for i, (panel_name, panel_tables, panel_csv) in enumerate(zip(model_names, model_data, csv_files)):
    create_subplot_with_errorbars(axes[i], i, panel_name, *panel_tables, panel_csv)
   

# Shared legend: (display-name key, legend label) in left-to-right order.
# Colours and markers are looked up in `colors` / `marker_map` so the legend
# cannot drift from what the panels actually draw.
legend_entries = [
    ('Greedy', 'Greedy'),
    ('Mean of m', r'Mean of $m$'),
    ('Shortest', 'Shortest'),
    ('LM consensus', 'LM consensus'),
    ('MBR', 'MBR'),
    ('QwQ 32B', 'QwQ 32B'),
    ('ASC', 'ASC'),
    ('ConGrs', r'$\bf{ConGrs}$ $\bf{(ours)}$'),
]

# Marker-only proxy handles (no data, no connecting line).
legend_handles = [
    matplotlib.lines.Line2D([], [], color=colors[key], marker=marker_map[key],
                            markersize=20, linestyle='None')
    for key, _ in legend_entries
]

# A single row of entries whose top-centre is anchored at y = -0.000269,
# i.e. just below the bottom edge of the figure.
legend = fig.legend(legend_handles,
                    [label for _, label in legend_entries],
                    bbox_to_anchor=(0.5375, -0.000269), loc='upper center',
                    fontsize=38, ncol=len(legend_entries))

# Vertical dataset label along the left edge of the figure.
fig.text(0.02, 0.56, 'PopQA', fontsize=60, fontweight='bold',
         rotation=90, va='center', ha='center')

plt.tight_layout()
# Applied after tight_layout so these margins (room for the dataset label on
# the left, extra space at the bottom) override the ones it computed.
plt.subplots_adjust(left=0.08, bottom=0.15)
plt.savefig('popqa_subplots.pdf', bbox_inches='tight')